# -*- coding:utf-8 -*-

from __future__ import division
import tensorflow as tf
from tqdm import trange
from utils.io_utils import conveter_model_path
from utils.log_utils import ClampLog, make_summary, log_debug
from utils.tb_utils import run_tensorboard
from eval.eval import get_accuracy
from model.loss.loss import get_loss
from model.optimzer.learning_rate import get_learning_rate
from model.optimzer.optimizer import get_optimizer
from config.glob.global_pool import global_pool

"""
model
"""


class Model:
    """Base TF1 model: owns the session, the input placeholders and the train/eval loop.

    ``train`` reads ``self.y_pred``, so the network output must be assigned to
    it (built on ``self.xs``) before ``train`` is called.
    NOTE(review): presumably a subclass does this — confirm against callers.
    """

    # Run validation every ``validate_every`` epochs (1 == after every epoch).
    # A value <= 0 disables periodic validation.
    validate_every = 1

    def __init__(self):
        self.sess = tf.Session()
        self.xs = None
        self.ys = None
        self.y_pred = None
        self.in_predict = None  # prediction module (not used in this class)
        self.dataset = global_pool.dataset_op
        self.creat_xs_ys()

    def creat_xs_ys(self):
        """Create the ``xs`` / ``ys`` input placeholders from the global config.

        The leading dimension of each placeholder is ``None`` so any number of
        rows can be fed; the remaining dimensions come from
        ``config.xs_shape`` / ``config.ys_shape``. A config dtype of ``'int'``
        maps to ``tf.int64``; any other value maps to ``tf.float32``.
        :return: None (sets ``self.xs`` and ``self.ys``)
        """
        config = global_pool.config
        xs_shape = [None] + list(config.xs_shape)
        ys_shape = [None] + list(config.ys_shape)

        xs_dtype = tf.int64 if config.xs_dtype == 'int' else tf.float32
        ys_dtype = tf.int64 if config.ys_dtype == 'int' else tf.float32

        self.xs = tf.placeholder(xs_dtype, xs_shape, name='xs')
        self.ys = tf.placeholder(ys_dtype, ys_shape, name='ys')

    def train(self):
        """Build loss/accuracy/optimizer ops and run the backprop training loop.

        Runs ``config.epoch`` epochs over the training split, writing loss,
        learning-rate and accuracy summaries to TensorBoard. Validates every
        ``validate_every`` epochs when a validation split exists, evaluates the
        test split (if any) after training, and saves a checkpoint when
        ``config.save.is_save`` is set. The summary writer and the session are
        closed when this method returns, including when it raises.
        :return: None
        """
        config = global_pool.config
        summary_writer = None
        try:
            global_step = tf.Variable(0, trainable=False)

            with ClampLog('building loss'), tf.name_scope('loss'):  # loss function
                loss = get_loss(self.ys, self.y_pred)
                tf.summary.scalar('loss', loss)

            with tf.name_scope('accuracy'):
                accuracy = get_accuracy(self.ys, self.y_pred)

            with ClampLog('building optimizer'), tf.name_scope('train'):  # learning rate and optimizer
                learning_rate = get_learning_rate(global_step, config.learning_rate)
                tf.summary.scalar('learning_rate', learning_rate)
                train_step = get_optimizer(config, learning_rate, loss, global_step)

            # The Saver snapshots the variable list at construction time, so it is
            # built only after every variable (incl. optimizer slots) exists.
            saver = None
            if config.save.is_save:
                with ClampLog("building saver"):
                    saver = tf.train.Saver(max_to_keep=1)

            tf.global_variables_initializer().run(session=self.sess)
            with ClampLog('saving board log'):
                merged = tf.summary.merge_all()
                # Derive log_dir and save_model_dir from the config.
                log_dir, save_model_dir = conveter_model_path(config)
                summary_writer = tf.summary.FileWriter(log_dir, self.sess.graph)

            run_tensorboard(log_dir)  # launch TensorBoard on log_dir

            fetches = [train_step, accuracy, loss, learning_rate, merged, global_step]
            for epoch in range(config.epoch):
                self._train_epoch(epoch, fetches, summary_writer)

                every = self.validate_every
                if self.dataset.validation.dataset and every > 0 and (epoch + 1) % every == 0:
                    self.validate(self.sess, epoch, accuracy, summary_writer)

            if self.dataset.test.dataset:  # only when a test split exists
                self.test(self.sess, accuracy, summary_writer)

            if saver is not None:
                with ClampLog('saving model'):
                    saver.save(self.sess, save_model_dir)
        finally:
            # FileWriter buffers events; close() flushes them to disk.
            if summary_writer is not None:
                summary_writer.close()
            self.sess.close()

    def _train_epoch(self, epoch, fetches, summary_writer):
        """Run one pass over the training split, one optimizer step per batch.

        :param epoch: 0-based epoch index (used only for display)
        :param fetches: [train_step, accuracy, loss, learning_rate, merged, global_step]
        :param summary_writer: FileWriter receiving per-step summaries
        """
        self.dataset.train.init_itreator()
        with trange(self.dataset.train.batch_num, desc='\033[33m train   ') as t:
            for _ in t:
                x, y = self.dataset.train.next_batch()
                _, acc, loss_val, lr_val, summary_str, step = self.sess.run(
                    fetches, feed_dict={self.xs: x, self.ys: y}
                )
                summary_writer.add_summary(
                    make_summary('train/accuracy', acc), global_step=step
                )
                summary_writer.add_summary(summary_str, step)
                info = 'epoch:{}, loss:{:.4f}, lr:{:.5f}, accuracy:{:.4f}, steps:{}'\
                    .format(epoch, loss_val, lr_val, acc, step)
                t.set_postfix_str(info)

    def validate(self, sess, epoch, accuracy, summary_writer):
        """Evaluate accuracy on the validation split and log it.

        Only the first batch returned after re-initialising the iterator is
        evaluated. NOTE(review): this assumes the validation iterator yields the
        whole split in one batch — confirm.
        :param sess: session to run ``accuracy`` in
        :param epoch: epoch index, used as the summary step
        :param accuracy: accuracy tensor
        :param summary_writer: FileWriter receiving 'validation/accuracy'
        :return: None
        """
        self.dataset.validation.init_itreator()
        x, y = self.dataset.validation.next_batch()
        acc = sess.run(accuracy, feed_dict={self.xs: x, self.ys: y})
        summary_writer.add_summary(make_summary('validation/accuracy', acc), global_step=epoch)
        info = '\033[34mvalidate  epoch:{}, precision:{:.4f}\033[0m'.format(epoch, acc)
        log_debug(info)

    def test(self, sess, accuracy, summary_writer):
        """Evaluate accuracy on the test split and log it (summary step 0).

        Only the first batch returned after re-initialising the iterator is
        evaluated. NOTE(review): this assumes the test iterator yields the whole
        split in one batch — confirm.
        :param sess: session to run ``accuracy`` in
        :param accuracy: accuracy tensor
        :param summary_writer: FileWriter receiving 'test/accuracy'
        :return: None
        """
        with ClampLog('test'):
            self.dataset.test.init_itreator()
            x, y = self.dataset.test.next_batch()
            acc = sess.run(accuracy, feed_dict={self.xs: x, self.ys: y})
            summary_writer.add_summary(make_summary('test/accuracy', acc), global_step=0)
            info = '\033[35mtest  precision:{:.4f}\033[0m'.format(acc)
            log_debug(info)


